Recap: yesterday we finished model.py and dataloader.py. I like to split these into two files, because once the model code grows, it no longer belongs in the same file as train_step and friends.
As before, there are a few fixed pieces here. Note that these method names are all fixed; Lightning's Trainer looks them up by name (see the sketch after the code below).
import torch.nn as nn
from torch.utils.data import DataLoader
import lightning as pl
from model import MNISTClassifier
from dataloader import CustomDataset

class example(pl.LightningModule):
    def __init__(
        self,
        batch_size = 16,
        train_txt = "/ws/code/Day8/train.txt",
        val_txt = "/ws/code/Day8/test.txt",
    ):
        super().__init__()
        self.batch_size = batch_size
        self.train_dataset = CustomDataset(train_txt)
        self.val_dataset = CustomDataset(val_txt)
        self.model = MNISTClassifier()
        # the training_step we fill in later calls self.loss_fn,
        # so define it here (assuming cross-entropy for MNIST classification)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, batch):
        pass

    def training_step(self, batch, batch_idx):
        pass

    def validation_step(self, batch, batch_idx):
        pass

    def configure_optimizers(self):
        pass

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size = self.batch_size,
            shuffle = True,       # reshuffle training data every epoch
            drop_last = True,
            num_workers = 4,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size = self.batch_size,
            shuffle = False,      # keep validation order deterministic
            drop_last = True,
            num_workers = 4,
        )
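To make the "fixed names" point concrete: you never call these methods yourself; the Trainer does. A minimal sketch of the wiring, assuming default Trainer settings (max_epochs is an arbitrary pick):

model = example()
trainer = pl.Trainer(max_epochs = 5)
trainer.fit(model)   # Trainer calls train_dataloader / training_step / val_dataloader ... by name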
Here batch is exactly what our dataloader from before hands back: we return two things, image and label, and inside the step I unpack them as x, y to keep it simple.
This is also what I meant earlier about everything being written as one small function each, simple and clear. You might notice that the loss.backward() you'd usually write in PyTorch is missing; that's because training_step has it all wrapped up for you, and Lightning runs the backward pass (and optimizer step) on its own. Of course, you can also update manually (see the sketch after the code below).
The flow looks like this:
def training_step(self, batch, batch_idx):
    x, y = batch
    preds = self.model(x)
    loss = self.loss_fn(preds, y).mean()
    self.log("train_loss", loss.item(), prog_bar = True)
    return loss   # Lightning runs backward() and optimizer.step() on this for you
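If you do want the manual updates mentioned above, Lightning supports manual optimization. Here is a minimal sketch, assuming the same model and loss as above (ManualExample is a hypothetical name, not part of this project):

import torch
import torch.nn as nn
import lightning as pl
from model import MNISTClassifier

class ManualExample(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False   # turn off Lightning's automatic loop
        self.model = MNISTClassifier()
        self.loss_fn = nn.CrossEntropyLoss()

    def training_step(self, batch, batch_idx):
        x, y = batch
        opt = self.optimizers()       # the optimizer returned by configure_optimizers
        opt.zero_grad()
        loss = self.loss_fn(self.model(x), y)
        self.manual_backward(loss)    # use this instead of loss.backward()
        opt.step()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr = 1e-3)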
Basically the first half is the same as training_step; the difference is that here I like to add metrics on top, i.e. something to evaluate how well the model is training. Below are a few that come up often for different kinds of tasks; pick according to your own task.
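A rough menu (my own picks for illustration, not an exhaustive list) of what torchmetrics offers out of the box:

import torchmetrics

# classification
acc = torchmetrics.classification.Accuracy(task = "multiclass", num_classes = 10)
f1 = torchmetrics.classification.F1Score(task = "multiclass", num_classes = 10)
# regression
mse = torchmetrics.MeanSquaredError()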
Here we'll implement the simplest one, accuracy, computed with the torchmetrics package.
There are two ways to write this with torchmetrics, and I'll go with the second: update in validation_step, and compute in on_validation_epoch_end. Think of it as validation calling update on every batch; once the whole epoch has run, compute produces the final result, which gets recorded in the log, and reset clears the state for the next epoch's calculation.
import torch.nn as nn
from torch.utils.data import DataLoader
import torchmetrics
import lightning as pl
from model import MNISTClassifier
from dataloader import CustomDataset

class example(pl.LightningModule):
    def __init__(
        self,
        batch_size = 16,
        train_txt = "/ws/code/Day8/train.txt",
        val_txt = "/ws/code/Day8/test.txt",
    ):
        super().__init__()
        self.batch_size = batch_size
        self.train_dataset = CustomDataset(train_txt)
        self.val_dataset = CustomDataset(val_txt)
        self.model = MNISTClassifier()
        self.loss_fn = nn.CrossEntropyLoss()   # used in validation_step below; assuming cross-entropy
        self.valid_acc = torchmetrics.classification.Accuracy(task = "multiclass", num_classes = 10)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        preds = self.model(x)
        loss = self.loss_fn(preds, y).mean()
        self.log("val/loss", loss.item(), prog_bar = True)
        self.valid_acc.update(preds, y)   # accumulate per-batch state

    def on_validation_epoch_end(self):
        self.log('valid_acc_epoch', self.valid_acc.compute())   # aggregate over the whole epoch
        self.valid_acc.reset()   # clear state for the next epoch
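For comparison, the first style I mentioned skips the epoch-end hook: you call the metric inside the step and hand the metric object itself to self.log, letting Lightning do the epoch aggregation. A minimal sketch:

def validation_step(self, batch, batch_idx):
    x, y = batch
    preds = self.model(x)
    self.valid_acc(preds, y)   # calling the metric updates its internal state
    # logging the metric object lets Lightning handle compute/reset per epoch
    self.log("valid_acc", self.valid_acc, on_step = False, on_epoch = True)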
This part sets up the optimizer and lr_scheduler. The simplest approach is a fixed learning rate, but you can also adjust it over time through a scheduler (see the sketch after the code below).
self.parameters() here grabs all the trainable parameters, which in our case means everything inside self.model.
def configure_optimizers(self):
    # requires import torch at the top of the file
    optimizer = torch.optim.Adam(self.parameters(), lr = 1e-3)
    # an lr_scheduler could be set up here as well
    return optimizer   # or: return [optimizer], [lr_scheduler]
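If you do add a scheduler, the return becomes two lists. A minimal sketch, where StepLR and its numbers are arbitrary picks for illustration:

def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr = 1e-3)
    # decay the learning rate by 10x every 10 epochs (arbitrary values)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 10, gamma = 0.1)
    return [optimizer], [lr_scheduler]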
That's it for today's update; take some time to digest it.
Tomorrow I'll finish up the rest.